# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Student Trainer classe."""

from flax import jax_utils
from gift.pipelines import end2end
from gift.pipelines import multi_env_end2end
from gift.pipelines import pipeline_utils


class StudentEnd2EndTrainer(end2end.End2end):
  """Student end2end training pipeline.

  Trains a student model on pseudo labels produced by a teacher model that is
  restored from a checkpoint (see `setup_teacher`).
  """

  def __init__(self, model_cls, task, hparams, experiment_dir,
               tb_summary_writer, rng):
    """Builds the base end2end pipeline, then loads the teacher.

    Args:
      model_cls: Forwarded to `end2end.End2end.__init__`.
      task: Forwarded to `end2end.End2end.__init__`; its dataset is later
        re-wired to use the teacher's pseudo labels.
      hparams: Forwarded to `end2end.End2end.__init__`; must define
        `teacher_data_augmentations` and `teacher_inputs_key`.
      experiment_dir: Forwarded to `end2end.End2end.__init__`.
      tb_summary_writer: Forwarded to `end2end.End2end.__init__`.
      rng: JAX PRNG key, forwarded to the base class and also used to build
        the teacher model.

    Raises:
      ValueError: If exactly one of `hparams.teacher_data_augmentations` and
        `hparams.teacher_inputs_key` is set.
    """
    super().__init__(model_cls, task, hparams, experiment_dir,
                     tb_summary_writer, rng)

    # Check compatibility of hparams: if teacher_data_augmentations is set,
    # teacher_inputs_key must also be set (e.g. to 'teacher_inputs'),
    # otherwise setting teacher_data_augmentations has no effect. The check is
    # symmetric: the two must be either both set or both None.
    # An explicit raise is used instead of `assert` so the check still runs
    # under `python -O`.
    has_teacher_augmentations = hparams.teacher_data_augmentations is not None
    has_teacher_inputs_key = hparams.teacher_inputs_key is not None
    if has_teacher_augmentations != has_teacher_inputs_key:
      raise ValueError(
          'hparams.teacher_data_augmentations and hparams.teacher_inputs_key '
          'must either both be set or both be None; got '
          f'teacher_data_augmentations={hparams.teacher_data_augmentations!r}, '
          f'teacher_inputs_key={hparams.teacher_inputs_key!r}.')

    # Load the teacher model and reset pseudo label generator of the dataset
    # to use the teacher's predictions.
    self.setup_teacher(rng)

  def setup_teacher(self, rng):
    """Loads the teacher model and re-wires the dataset to use it.

    The teacher's config, checkpoint and checkpoint step come from
    `self.load_teacher_info(self.hparams)`. A pseudo label generator built
    from the teacher is then installed via
    `self.task.dataset.reset_pseudo_label_generator`.

    Args:
      rng: JAX PRNG key passed to `pipeline_utils.load_model` when building
        the teacher model.

    Returns:
      None. Sets `self.teacher_train_state` as a side effect.
    """
    (teacher_config, teacher_ckpnt,
     teacher_ckpnt_step) = self.load_teacher_info(self.hparams)

    # Create the teacher model and load it from the checkpoint path.
    teacher_train_state = pipeline_utils.load_model(
        rng=rng,
        model_config=teacher_config,
        model_ckpt=teacher_ckpnt,
        task=self.task,
        checkpoint_step=teacher_ckpnt_step)
    # Replicate the teacher train state across local devices.
    self.teacher_train_state = jax_utils.replicate(teacher_train_state)

    # `get(key, default)` returns a stored None rather than the default, and
    # the check in `__init__` allows teacher_inputs_key to be None, so map
    # None to the default input key explicitly.
    teacher_inputs_key = self.hparams.get('teacher_inputs_key')
    if teacher_inputs_key is None:
      teacher_inputs_key = 'inputs'

    # Build a pseudo label generator to be passed to the dataset class,
    # that uses the teacher model to produce pseudo labels.
    pseudo_label_generator = pipeline_utils.get_pseudo_label_generator(
        train_state=self.teacher_train_state,
        train=False,
        input_key=teacher_inputs_key,
        temp=self.hparams.get('label_temp', 1.0),
        confidence_quantile_threshold=self.hparams.get(
            'confidence_quantile_threshold', 0.1),
        self_supervised_label_transformation=self.hparams.get(
            'self_supervised_label_transformation', 'soft'))

    # Reload the training data iterator with a new process_data stage where
    # ground truth labels are replaced with pseudo labels generated by the
    # teacher model.
    self.task.dataset.reset_pseudo_label_generator(pseudo_label_generator)


class StudentMultiEnvEnd2EndTrainer(multi_env_end2end.MultiEnvEnd2End):
  """Student multi-environment end2end training pipeline.

  Trains a student model on pseudo labels produced by a teacher model that is
  restored from a checkpoint (see `setup_teacher`).
  """

  def __init__(self, model_cls, task, hparams, experiment_dir,
               tb_summary_writer, rng):
    """Builds the base multi-env end2end pipeline, then loads the teacher.

    Args:
      model_cls: Forwarded to `multi_env_end2end.MultiEnvEnd2End.__init__`.
      task: Forwarded to the base class; its dataset is later re-wired to use
        the teacher's pseudo labels.
      hparams: Forwarded to the base class; also read by `setup_teacher`.
      experiment_dir: Forwarded to the base class.
      tb_summary_writer: Forwarded to the base class.
      rng: JAX PRNG key, forwarded to the base class and also used to build
        the teacher model.
    """
    super().__init__(model_cls, task, hparams, experiment_dir,
                     tb_summary_writer, rng)

    # Load the teacher model and reset pseudo label generator of the dataset
    # to use the teacher's predictions.
    self.setup_teacher(rng)

  def setup_teacher(self, rng):
    """Loads the teacher model and re-wires the dataset to use it.

    The teacher's config, checkpoint and checkpoint step come from
    `self.load_teacher_info(self.hparams)`. A pseudo label generator built
    from the teacher is then installed via
    `self.task.dataset.reset_pseudo_label_generator`.

    Args:
      rng: JAX PRNG key passed to `pipeline_utils.load_model` when building
        the teacher model.

    Returns:
      None. Sets `self.teacher_train_state` as a side effect.
    """
    (teacher_config, teacher_ckpnt,
     teacher_ckpnt_step) = self.load_teacher_info(self.hparams)

    # Create the teacher model and load it from the checkpoint path.
    teacher_train_state = pipeline_utils.load_model(
        rng=rng,
        model_config=teacher_config,
        model_ckpt=teacher_ckpnt,
        task=self.task,
        checkpoint_step=teacher_ckpnt_step)
    # Replicate the teacher train state across local devices.
    self.teacher_train_state = jax_utils.replicate(teacher_train_state)

    # `get(key, default)` returns a stored None rather than the default, so a
    # config that declares teacher_inputs_key = None must be mapped to the
    # default input key explicitly.
    teacher_inputs_key = self.hparams.get('teacher_inputs_key')
    if teacher_inputs_key is None:
      teacher_inputs_key = 'inputs'

    # Build a pseudo label generator to be passed to the dataset class,
    # that uses the teacher model to produce pseudo labels.
    pseudo_label_generator = pipeline_utils.get_pseudo_label_generator(
        train_state=self.teacher_train_state,
        train=False,
        input_key=teacher_inputs_key,
        temp=self.hparams.get('label_temp', 1.0),
        confidence_quantile_threshold=self.hparams.get(
            'confidence_quantile_threshold', 0.1),
        self_supervised_label_transformation=self.hparams.get(
            'self_supervised_label_transformation', 'soft'))

    # Reload the training data iterators with a new process_data stage where
    # ground truth labels are replaced with pseudo labels generated by the
    # teacher model.
    self.task.dataset.reset_pseudo_label_generator(pseudo_label_generator)
